# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


import tensorflow as tf

# Run TF 1.x-style ops eagerly so this script executes imperatively
# (no explicit Session/graph needed for the smoke test below).
tf.compat.v1.enable_eager_execution()


def _round_up_tf(x, multiple):
    """Round the scalar integer tensor `x` up to the nearest multiple of `multiple`.

    TensorFlow equivalent of:
        return x if x % multiple == 0 else x + multiple - (x % multiple)

    Args:
        x: scalar int tensor to round up.
        multiple: scalar int tensor; the rounding granularity.

    Returns:
        Scalar int tensor: the smallest multiple of `multiple` that is >= `x`.
    """
    remainder = tf.mod(x, multiple)
    # BUGFIX: the original passed `-remainder` as an extra positional argument
    # to tf.cond instead of including it in the false branch, which raises a
    # TypeError and would not compute `x + multiple - remainder` as intended.
    x_round = tf.cond(
        tf.equal(remainder, 0),
        lambda: x,
        lambda: x + multiple - remainder,
    )

    return x_round


def sequence_mask(lengths, r, expand=True):
    """Build a float32 sequence mask whose time axis is padded up to a multiple of r.

    Args:
        lengths: 1-D int tensor of valid sequence lengths per batch entry.
        r: int (or scalar tensor) reduction factor; max length is rounded up
           to a multiple of r so the mask matches decoder outputs when r > 1.
        expand: if True, append a trailing singleton axis to the mask.

    Returns:
        A [batch, max_len] mask, or [batch, max_len, 1] when `expand` is True.
    """
    padded_len = _round_up_tf(tf.reduce_max(lengths), tf.convert_to_tensor(r))
    mask = tf.sequence_mask(lengths, maxlen=padded_len, dtype=tf.float32)
    if not expand:
        return mask
    return tf.expand_dims(mask, axis=-1)


def MaskedSigmoidCrossEntropy(
    targets,
    outputs,
    targets_lengths,
    outputs_per_step,
    cross_entropy_pos_weight,
    mask=None,
):
    """Weighted sigmoid cross-entropy (with logits) averaged over unmasked steps.

    Args:
        targets: ground-truth stop-token labels, [batch, time].
        outputs: predicted stop-token logits, same shape as `targets`.
        targets_lengths: 1-D int tensor of valid lengths per batch entry.
        outputs_per_step: reduction factor r; the auto-built mask's time axis
            is rounded up to a multiple of r (extra r>1 padding stays masked).
        cross_entropy_pos_weight: pos_weight for the weighted cross-entropy;
            1 reduces it to vanilla tf.nn.sigmoid_cross_entropy_with_logits.
        mask: optional precomputed [batch, time] float mask; built from
            `targets_lengths` when None. Example:
            sequence_mask([1, 3, 2], 5) = [[1., 0., 0., 0., 0.],
                                           [1., 1., 1., 0., 0.],
                                           [1., 1., 0., 0., 0.]]

    Returns:
        Scalar tensor: sum of masked losses divided by the number of non-zero
        masked-loss entries.
    """
    if mask is None:
        mask = sequence_mask(targets_lengths, outputs_per_step, False)

    shape_guard = tf.assert_equal(tf.shape(targets), tf.shape(mask))
    with tf.control_dependencies([shape_guard]):
        elementwise = tf.nn.weighted_cross_entropy_with_logits(
            targets=targets, logits=outputs, pos_weight=cross_entropy_pos_weight
        )

    loss_guard = tf.assert_equal(tf.shape(mask), tf.shape(elementwise))
    with tf.control_dependencies([loss_guard]):
        masked = elementwise * mask

    # NOTE(review): the denominator counts non-zero entries of the *masked
    # loss*, not of the mask itself — an unmasked step whose loss is exactly 0
    # would be excluded from the average. Preserved as-is; confirm intended.
    return tf.reduce_sum(masked) / tf.count_nonzero(masked, dtype=tf.float32)


# Example decoder <stop_token> logits for one utterance (batch size 1,
# 264 frames). Every value is strongly negative, i.e. sigmoid(logit) ~ 0:
# the model never predicts "stop" anywhere in this sample.
stop_token_prediction = tf.Variable(
    [[
        -19.2963486, -19.260931, -20.0291309, -23.1824493, -22.9772644, -20.2885036,
        -27.675539, -27.3030777, -29.577425, -28.1723957, -27.8901939, -26.384119,
        -27.3650894, -27.1803169, -23.7659645, -27.9101353, -27.6226883, -28.7287712,
        -26.2662029, -26.041851, -26.4625721, -24.2690239, -24.1902237, -23.376133,
        -22.6833057, -22.6326752, -23.0127048, -22.7092934, -22.6531601, -20.4668312,
        -22.9737015, -22.949688, -19.7955379, -23.9654541, -24.1125488, -16.8582115,
        -24.5958805, -24.619318, -19.6350899, -26.0553, -25.7573509, -29.0456867,
        -24.1830368, -24.0690536, -26.6302032, -24.7784214, -24.6312675, -26.5384,
        -22.3738, -22.324398, -23.8008652, -22.1903877, -22.2572346, -21.1526031,
        -22.0404415, -22.1381817, -23.1001148, -21.6368828, -21.7594643, -22.5351944,
        -22.5584583, -22.7108879, -20.525753, -23.0564251, -23.1634941, -22.7847729,
        -24.2631569, -24.3629684, -25.3400059, -22.6183834, -22.8448315, -23.9382324,
        -23.1020679, -23.233799, -22.6955643, -22.4547825, -22.6073112, -21.2976532,
        -22.9158459, -22.9708233, -22.4025, -23.5523796, -23.5779228, -23.5215797,
        -24.3425903, -24.4970589, -24.7781792, -25.6408978, -25.9453678, -25.3536453,
        -25.6538544, -25.8861675, -26.3280144, -24.8373375, -24.8233719, -23.5885773,
        -23.9586372, -23.6939812, -22.1724129, -23.5438118, -23.3040047, -23.8013535,
        -21.5034828, -21.3676109, -23.6252117, -20.5316467, -20.62747, -22.1351776,
        -23.3784924, -23.4007187, -23.1962986, -23.2673321, -23.3588448, -20.4417667,
        -23.7708473, -23.9409142, -22.6762218, -23.636425, -23.8731575, -23.8063,
        -24.4575443, -24.6403236, -24.1111279, -24.729847, -24.9316673, -23.4876862,
        -25.3459587, -25.4803429, -23.573637, -23.451, -23.5613155, -22.4887867,
        -23.6182842, -23.6641579, -22.0665321, -24.9837589, -25.0286808, -22.7693596,
        -23.6526451, -23.8223419, -22.7516861, -23.4104614, -23.6487885, -23.163723,
        -23.4834728, -23.7507381, -22.5205746, -23.231369, -23.4580765, -22.335535,
        -23.637701, -23.7155457, -22.711525, -23.7864494, -23.7470684, -21.1830482,
        -24.4629459, -24.4870872, -22.1926651, -24.3497772, -24.3965874, -23.1209297,
        -24.1713791, -24.2138214, -21.8886795, -24.6122303, -24.8375359, -22.7196045,
        -24.7429218, -24.939991, -22.5758209, -25.1342621, -25.2911663, -22.7056255,
        -25.5367126, -25.6215458, -24.2621346, -24.959343, -25.1273861, -25.5547142,
        -23.950901, -24.1237774, -23.0260277, -24.1956635, -24.4791393, -22.6198807,
        -25.1723881, -25.4708157, -24.9331245, -26.1021328, -26.4309196, -25.2354126,
        -24.9315262, -25.1621246, -23.9584408, -24.5821438, -24.8249702, -22.8991299,
        -24.406395, -24.5078354, -21.9084454, -24.3823357, -24.4428082, -21.5360947,
        -24.3317642, -24.3363438, -21.1170616, -24.9145298, -24.8127346, -22.7880783,
        -24.8024845, -24.6873283, -22.7920494, -24.7602634, -24.710001, -21.3401356,
        -21.2255783, -21.2736683, -17.7675343, -23.8382912, -23.7931099, -23.3757706,
        -24.4784184, -24.4339848, -23.804306, -23.7379055, -23.6727619, -21.8263683,
        -23.5269566, -23.4892349, -21.7052937, -24.1569557, -24.0990562, -23.29496,
        -23.8534336, -23.7737827, -22.0034542, -24.1009197, -24.037159, -22.1008091,
        -23.5646248, -23.5025635, -21.9216061, -24.4235058, -24.3794708, -22.3685799,
        -24.6468735, -24.5811653, -22.5814056, -24.148201, -24.0724564, -21.335146,
        -23.6016941, -23.5277939, -21.3815708, -24.4757156, -24.3964157, -22.8095894,
        -24.6821842, -24.6291065, -22.366869, -22.4068909, -22.4802837, -18.4135742,
    ]],
    dtype=tf.float32,
)

# Ground-truth <stop_token> targets for the same utterance: 0 for every
# frame except the final (264th) one, which marks the end of the sequence.
stop_token_target = tf.Variable(
    [[0] * 263 + [1]],
    dtype=tf.float32,
)

# Valid length (in frames) of the single example in the batch.
targets_lengths = tf.Variable([264], dtype=tf.int32)

# Smoke test: eagerly compute the masked stop-token loss (reduction factor
# r=3, pos_weight=1) and differentiate it with respect to the logits.
with tf.GradientTape() as tape:
    # Variables are watched by default, but the explicit watch keeps the
    # intent obvious and is harmless.
    tape.watch(stop_token_prediction)
    stop_token_loss = MaskedSigmoidCrossEntropy(
        stop_token_target, stop_token_prediction, targets_lengths, 3, 1
    )

print(stop_token_loss)

# Gradient of the scalar loss w.r.t. the prediction logits.
g = tape.gradient(stop_token_loss, stop_token_prediction)

print(g)
